"""Common utilities for creating J2WASM targets and providers."""

load("@bazel_skylib//lib:structs.bzl", "structs")
load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
load(":j2cl_common.bzl", "J2CL_TOOLCHAIN_ATTRS", "j2cl_common")
load(":provider.bzl", "J2wasmInfo")

def _compile(
        ctx,
        deps = [],
        exports = [],
        artifact_suffix = "",
        **kwargs):
    """Performs compilation for all feature sets.

    Args:
      ctx: The rule context. It must carry the `_feature_set` attribute, whose `BuildSettingInfo`
        value is the feature set that `J2WASM_FEATURE_SET.DEFAULT` resolves to.
      deps: List of dependency providers. This should be a list of `J2wasmInfo` and `JsInfo`s.
      exports: List of export providers. This should be a list of `J2wasmInfo` and `JsInfo`s.
      artifact_suffix: The artifact suffix for the default feature set. Other feature sets will add
        an additional suffix.
      **kwargs: Remaining arguments, forwarded unchanged to `_compile_feature_set` (for example
        `name`, `srcs`, `plugins`, `exported_plugins` and `javac_opts`).

    Returns:
      A J2wasmInfo provider that contains a map from each feature set to the compilation struct
      for that compilation.
    """

    build_setting_feature_set = ctx.attr._feature_set[BuildSettingInfo].value

    # Compilation structs keyed by the effective (resolved) feature set. The default feature set
    # and the build setting's feature set resolve to the same compilation with the same artifact
    # suffix, so the work is done once and the resulting struct is shared between both map entries.
    compiled_by_feature_set = {}

    def _compile_for_feature_set(feature_set):
        # Reduce redundancy by returning the identical compilation info if the requested feature set
        # is the same as the build setting or is not present.
        naming_feature_set = feature_set
        if feature_set == build_setting_feature_set or feature_set == J2WASM_FEATURE_SET.DEFAULT:
            naming_feature_set = ""
            feature_set = build_setting_feature_set

        if feature_set in compiled_by_feature_set:
            return compiled_by_feature_set[feature_set]

        feature_set_suffix = "-%s" % naming_feature_set if naming_feature_set else ""
        feature_artifact_suffix = artifact_suffix + feature_set_suffix

        # At this point, deps/exports are J2wasmInfo or JsInfo. j2cl_common.compile does not accept
        # a J2wasmInfo, so we must first extract the feature-specific J2clInfo from it.
        feature_deps = _extract_j2cl_infos(deps, feature_set)
        feature_exports = _extract_j2cl_infos(exports, feature_set)

        internal_transpiler_flags = {}
        if feature_set == J2WASM_FEATURE_SET.CUSTOM_DESCRIPTORS:
            internal_transpiler_flags["experimentalEnableWasmCustomDescriptors"] = True
        elif feature_set == J2WASM_FEATURE_SET.CUSTOM_DESCRIPTORS_JSINTEROP:
            internal_transpiler_flags["experimentalEnableWasmCustomDescriptorsJsInterop"] = True

        compilation = _compile_feature_set(
            ctx = ctx,
            deps = feature_deps,
            exports = feature_exports,
            artifact_suffix = feature_artifact_suffix,
            internal_transpiler_flags = internal_transpiler_flags,
            **kwargs
        )
        compiled_by_feature_set[feature_set] = compilation
        return compilation

    return _create_j2wasm_feature_set_provider(_compile_for_feature_set)

def _extract_j2cl_infos(deps, feature_set):
    """Replaces each J2WASM provider in `deps` with its entry for `feature_set`.

    Args:
      deps: List of providers. Entries carrying the `_is_j2wasm_provider` marker are unwrapped;
        every other entry is passed through untouched.
      feature_set: Key to look up in each J2WASM provider's `_private_.feature_set_map`.

    Returns:
      A new list, in the same order as `deps`, with the J2WASM providers unwrapped.
    """
    j2cl_infos = []
    for dep in deps:
        if hasattr(dep, "_is_j2wasm_provider"):
            j2cl_infos.append(dep._private_.feature_set_map[feature_set])
        else:
            j2cl_infos.append(dep)
    return j2cl_infos

def _compile_feature_set(
        ctx,
        name,
        srcs = [],
        deps = [],
        exports = [],
        plugins = [],
        exported_plugins = [],
        javac_opts = [],
        internal_transpiler_flags = {},
        artifact_suffix = ""):
    """Runs one `WASM_MODULAR` compilation and wraps its result.

    Args:
      ctx: The rule context, handed to `j2cl_common.compile`.
      name: Not read by this function; accepted so callers can forward it.
      srcs: Forwarded to `j2cl_common.compile`.
      deps: Forwarded to `j2cl_common.compile`; also scanned for transitive modules.
      exports: Forwarded to `j2cl_common.compile`; also scanned for transitive modules.
      plugins: Forwarded to `j2cl_common.compile`.
      exported_plugins: Forwarded to `j2cl_common.compile`.
      javac_opts: Caller-supplied javac options; `DEFAULT_J2WASM_JAVAC_OPTS` is appended to them.
      internal_transpiler_flags: Forwarded to `j2cl_common.compile`.
      artifact_suffix: Forwarded to `j2cl_common.compile`.

    Returns:
      The struct built by `_create_j2cl_provider` from the compilation result.
    """

    # The J2WASM defaults always come after whatever the caller asked for.
    all_javac_opts = javac_opts + DEFAULT_J2WASM_JAVAC_OPTS

    compilation = j2cl_common.compile(
        ctx = ctx,
        backend = "WASM_MODULAR",
        srcs = srcs,
        deps = deps,
        exports = exports,
        plugins = plugins,
        exported_plugins = exported_plugins,
        javac_opts = all_javac_opts,
        internal_transpiler_flags = internal_transpiler_flags,
        artifact_suffix = artifact_suffix,
    )

    # Both deps and exports may contribute transitive modules.
    return _create_j2cl_provider(compilation, deps + exports)

def _create_j2cl_provider(j2cl_provider, deps):
    """Wraps a compilation result together with its transitive modules.

    Args:
      j2cl_provider: The compilation result; its `_private_.java_info`, `_private_.js_info` and
        `_private_.output_js` fields are read.
      deps: List of providers. Only entries carrying the `_is_j2cl_provider` marker (which this
        function sets on its own result) contribute to the transitive modules.

    Returns:
      A struct marked with `_is_j2cl_provider` whose `_private_` holds `java_info`, `js_info`,
      `output` and a postorder `transitive_modules` depset.
    """
    compiled = j2cl_provider._private_

    # The output_js could be "None" if there are no sources; in that case this compilation adds
    # no direct module.
    module = compiled.output_js
    direct_modules = []
    if module:
        direct_modules.append(module)

    dep_modules = []
    for dep in deps:
        if hasattr(dep, "_is_j2cl_provider"):
            dep_modules.append(dep._private_.transitive_modules)

    all_modules = depset(
        direct_modules,
        transitive = dep_modules,
        order = "postorder",
    )

    # TODO(b/445759849): Move transitive_modules into J2clInfo itself and remove this struct.
    return struct(
        _private_ = struct(
            java_info = compiled.java_info,
            js_info = compiled.js_info,
            output = module,
            transitive_modules = all_modules,
        ),
        _is_j2cl_provider = 1,
    )

def _create_j2wasm_feature_set_provider(mapping_fn):
    """Builds a `J2wasmInfo` by evaluating `mapping_fn` once per known feature set.

    Args:
      mapping_fn: Function called with each value of `J2WASM_FEATURE_SET_VALUES`; its return value
        is stored under that feature set.

    Returns:
      A `J2wasmInfo` marked with `_is_j2wasm_provider` whose `_private_.feature_set_map` maps each
      feature set to `mapping_fn(feature_set)`.
    """
    per_feature_set = {}
    for feature_set in J2WASM_FEATURE_SET_VALUES:
        per_feature_set[feature_set] = mapping_fn(feature_set)

    return J2wasmInfo(
        _private_ = struct(feature_set_map = per_feature_set),
        _is_j2wasm_provider = 1,
    )

def _to_j2wasm_name(name):
    """Derives the J2WASM label name that corresponds to a J2CL label name.

    Args:
      name: The label name used for the J2CL target.

    Returns:
      For names ending in `j2cl_proto`, the name with that ending replaced by `j2wasm_proto`.
      Otherwise the name, minus a trailing `-j2cl` if it has one, followed by `-j2wasm`.
    """
    proto_suffix = "j2cl_proto"
    if name.endswith(proto_suffix):
        stem = name[:len(name) - len(proto_suffix)]
        return "%sj2wasm_proto" % stem

    j2cl_suffix = "-j2cl"
    stem = name
    if stem.endswith(j2cl_suffix):
        stem = stem[:len(stem) - len(j2cl_suffix)]
    return "%s-j2wasm" % stem

# Namespace struct that bundles the helpers of this file under non-underscore names:
#   compile:            compiles for every feature set and returns a J2wasmInfo.
#   create_j2wasm_info: builds a J2wasmInfo from a per-feature-set mapping function.
#   to_j2wasm_name:     converts a J2CL label name to its J2WASM counterpart.
j2wasm_common = struct(
    compile = _compile,
    create_j2wasm_info = _create_j2wasm_feature_set_provider,
    to_j2wasm_name = _to_j2wasm_name,
)

# Javac options that `_compile_feature_set` appends after the caller-supplied `javac_opts` for
# every J2WASM compilation.
DEFAULT_J2WASM_JAVAC_OPTS = [
    # Preserve the private fields in turbine compilation. In order to create wasm structs for
    # Java classes, all the fields from the super classes need to be seen even if compiling
    # in a different library.
    "-XDturbine.emitPrivateFields",
    # Disable analysis for thread safety since it does not expect to see
    # private fields from dependencies.
    "-Xep:ThreadSafe:OFF",
]

# Attribute dictionary for J2WASM: starts from a copy of the entries in J2CL_TOOLCHAIN_ATTRS and
# then applies the J2WASM-specific entries below. Because the second update runs last, its keys
# win over any same-named key coming from J2CL_TOOLCHAIN_ATTRS.
J2WASM_TOOLCHAIN_ATTRS = {}
J2WASM_TOOLCHAIN_ATTRS.update(J2CL_TOOLCHAIN_ATTRS)
J2WASM_TOOLCHAIN_ATTRS.update({
    # NOTE(review): presumably replaces the Java toolchain entry inherited from
    # J2CL_TOOLCHAIN_ATTRS with the J2WASM one; confirm against j2cl_common.bzl.
    "_j2cl_java_toolchain": attr.label(
        default = Label("//jre/java:j2wasm_java_toolchain"),
    ),
    # Build setting read by `_compile` (via `BuildSettingInfo`) to decide which feature set the
    # default compilation uses.
    "_feature_set": attr.label(
        providers = [BuildSettingInfo],
        default = "//build_defs/internal_do_not_use:j2wasm_feature_set",
    ),
})

# Known J2WASM feature sets. The string values are used as keys of a J2wasmInfo's feature set map
# and, for non-default feature sets, as an extra "-<value>" artifact suffix in `_compile`.
J2WASM_FEATURE_SET = struct(
    # Resolved by `_compile` to whatever feature set the `_feature_set` build setting holds.
    DEFAULT = "",
    # Makes `_compile` pass the "experimentalEnableWasmCustomDescriptors" transpiler flag.
    CUSTOM_DESCRIPTORS = "custom_descriptors",
    # Makes `_compile` pass the "experimentalEnableWasmCustomDescriptorsJsInterop" transpiler flag.
    CUSTOM_DESCRIPTORS_JSINTEROP = "custom_descriptors_jsinterop",
)

# All feature set values above; `_create_j2wasm_feature_set_provider` creates one map entry for
# each of them.
J2WASM_FEATURE_SET_VALUES = structs.to_dict(J2WASM_FEATURE_SET).values()
